{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "# Text Generation with miniature GPT\n",
    "\n",
    "**Author:** [Apoorv Nandan](https://twitter.com/NandanApoorv)<br>\n",
    "**Date created:** 2020/05/29<br>\n",
    "**Last modified:** 2020/05/29<br>\n",
    "**Description:** Implement a miniature version of GPT and train it to generate text."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Introduction\n",
    "\n",
    "This example demonstrates autoregressive language modelling using\n",
    "a miniature version of the GPT model.\n",
    "The model consists of a single transformer block with causal masking\n",
    "in its attention layer.\n",
    "We use the text from the IMDB sentiment classification dataset for training\n",
    "and generate new movie reviews for a given prompt.\n",
    "When using this script with your own data, make sure it has at least\n",
    "1M words.\n",
    "\n",
    "This example should be run with `tf-nightly>=2.3.0-dev20200531` or\n",
    "with tensorflow 2.3 or higher.\n",
    "\n",
    "**References:**\n",
    "\n",
    "- [GPT](https://www.semanticscholar.org/paper/Improving-Language-Understanding-by-Generative-Radford/cd18800a0fe0b668a1cc19f2ec95b5003d0a5035)\n",
    "- [GPT-2](https://www.semanticscholar.org/paper/Language-Models-are-Unsupervised-Multitask-Learners-Radford-Wu/9405cc0d6169988371b2755e573cc28650d14dfe)\n",
    "- [GPT-3](https://arxiv.org/abs/2005.14165)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Setup\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras import layers\n",
    "from tensorflow.keras.layers.experimental.preprocessing import TextVectorization\n",
    "import numpy as np\n",
    "import os\n",
    "import re\n",
    "import string\n",
    "import random\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Self-attention with causal masking\n",
    "\n",
    "We compute self-attention as usual, but prevent information from flowing\n",
    "from future tokens by masking the upper triangle of the scaled dot-product matrix.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "\n",
    "class MultiHeadSelfAttention(layers.Layer):\n",
    "    def __init__(self, embed_dim, num_heads=8):\n",
    "        super(MultiHeadSelfAttention, self).__init__()\n",
    "        self.embed_dim = embed_dim\n",
    "        self.num_heads = num_heads\n",
    "        if embed_dim % num_heads != 0:\n",
    "            raise ValueError(\n",
    "                f\"embedding dimension = {embed_dim} should be divisible by number of heads = {num_heads}\"\n",
    "            )\n",
    "        self.projection_dim = embed_dim // num_heads\n",
    "        self.query_dense = layers.Dense(embed_dim)\n",
    "        self.key_dense = layers.Dense(embed_dim)\n",
    "        self.value_dense = layers.Dense(embed_dim)\n",
    "        self.combine_heads = layers.Dense(embed_dim)\n",
    "\n",
    "    @staticmethod\n",
    "    def causal_attention_mask(n_dest, n_src, dtype):\n",
    "        \"\"\"\n",
    "        1's in the lower triangle, counting from the lower right corner.\n",
    "        \"\"\"\n",
    "        i = tf.range(n_dest)[:, None]\n",
    "        j = tf.range(n_src)\n",
    "        m = i >= j - n_src + n_dest\n",
    "        return tf.cast(m, dtype)\n",
    "\n",
    "    def attention(self, query, key, value):\n",
    "        score = tf.matmul(query, key, transpose_b=True)\n",
    "        dim_key = tf.cast(tf.shape(key)[-1], tf.float32)\n",
    "        scaled_score = score / tf.math.sqrt(dim_key)\n",
    "\n",
    "        # prevent information flow from future tokens\n",
    "        shape = tf.shape(scaled_score)\n",
    "        dim_dest, dim_src = shape[2], shape[3]\n",
    "        attention_mask = self.causal_attention_mask(\n",
    "            dim_dest, dim_src, scaled_score.dtype\n",
    "        )\n",
    "        attention_mask = tf.reshape(attention_mask, [1, 1, dim_dest, dim_src])\n",
    "        scaled_score = scaled_score * attention_mask - 1e4 * (1 - attention_mask)\n",
    "\n",
    "        weights = tf.nn.softmax(scaled_score, axis=-1)\n",
    "        output = tf.matmul(weights, value)\n",
    "        return output, weights\n",
    "\n",
    "    def separate_heads(self, x, batch_size):\n",
    "        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.projection_dim))\n",
    "        return tf.transpose(x, perm=[0, 2, 1, 3])\n",
    "\n",
    "    def call(self, inputs):\n",
    "        # x.shape = [batch_size, seq_len, embedding_dim]\n",
    "        batch_size = tf.shape(inputs)[0]\n",
    "        query = self.query_dense(inputs)  # (batch_size, seq_len, embed_dim)\n",
    "        key = self.key_dense(inputs)  # (batch_size, seq_len, embed_dim)\n",
    "        value = self.value_dense(inputs)  # (batch_size, seq_len, embed_dim)\n",
    "        query = self.separate_heads(\n",
    "            query, batch_size\n",
    "        )  # (batch_size, num_heads, seq_len, projection_dim)\n",
    "        key = self.separate_heads(\n",
    "            key, batch_size\n",
    "        )  # (batch_size, num_heads, seq_len, projection_dim)\n",
    "        value = self.separate_heads(\n",
    "            value, batch_size\n",
    "        )  # (batch_size, num_heads, seq_len, projection_dim)\n",
    "        attention, weights = self.attention(query, key, value)\n",
    "        attention = tf.transpose(\n",
    "            attention, perm=[0, 2, 1, 3]\n",
    "        )  # (batch_size, seq_len, num_heads, projection_dim)\n",
    "        concat_attention = tf.reshape(\n",
    "            attention, (batch_size, -1, self.embed_dim)\n",
    "        )  # (batch_size, seq_len, embed_dim)\n",
    "        output = self.combine_heads(\n",
    "            concat_attention\n",
    "        )  # (batch_size, seq_len, embed_dim)\n",
    "        return output\n",
    "\n"
   ]
  },
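  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "As a quick check of the masking logic, the sketch below rebuilds the comparison\n",
    "used in `causal_attention_mask` as a standalone function. `demo_causal_mask` and\n",
    "the other `demo_` names are hypothetical helpers for illustration, not part of\n",
    "the model. For a sequence of length 4, row `i` of the mask has ones only in\n",
    "columns `0..i`, so each position can attend to itself and to earlier positions.\n",
    "Applying the mask to all-zero scores shows that the softmax then puts its weight\n",
    "on the visible positions only.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "\n",
    "def demo_causal_mask(n_dest, n_src):\n",
    "    # Hypothetical helper: same comparison as causal_attention_mask above\n",
    "    i = tf.range(n_dest)[:, None]\n",
    "    j = tf.range(n_src)\n",
    "    return tf.cast(i >= j - n_src + n_dest, tf.float32)\n",
    "\n",
    "\n",
    "demo_mask = demo_causal_mask(4, 4)\n",
    "print(demo_mask.numpy())  # lower-triangular matrix of ones\n",
    "\n",
    "# Mask all-zero scores the same way the attention method does, then normalize\n",
    "demo_scores = tf.zeros((4, 4))\n",
    "demo_masked = demo_scores * demo_mask - 1e4 * (1 - demo_mask)\n",
    "print(tf.nn.softmax(demo_masked, axis=-1).numpy())\n"
   ]
  },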
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Implement a Transformer block as a layer\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "\n",
    "class TransformerBlock(layers.Layer):\n",
    "    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1):\n",
    "        super(TransformerBlock, self).__init__()\n",
    "        self.att = MultiHeadSelfAttention(embed_dim, num_heads)\n",
    "        self.ffn = keras.Sequential(\n",
    "            [layers.Dense(ff_dim, activation=\"relu\"), layers.Dense(embed_dim),]\n",
    "        )\n",
    "        self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)\n",
    "        self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)\n",
    "        self.dropout1 = layers.Dropout(rate)\n",
    "        self.dropout2 = layers.Dropout(rate)\n",
    "\n",
    "    def call(self, inputs):\n",
    "        attention_output = self.att(inputs)\n",
    "        attention_output = self.dropout1(attention_output)\n",
    "        out1 = self.layernorm1(inputs + attention_output)\n",
    "        ffn_output = self.ffn(out1)\n",
    "        ffn_output = self.dropout2(ffn_output)\n",
    "        return self.layernorm2(out1 + ffn_output)\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Implement embedding layer\n",
    "\n",
    "Two separate embedding layers: one for tokens and one for token positions.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "\n",
    "class TokenAndPositionEmbedding(layers.Layer):\n",
    "    def __init__(self, maxlen, vocab_size, emded_dim):\n",
    "        super(TokenAndPositionEmbedding, self).__init__()\n",
    "        self.token_emb = layers.Embedding(input_dim=vocab_size, output_dim=emded_dim)\n",
    "        self.pos_emb = layers.Embedding(input_dim=maxlen, output_dim=emded_dim)\n",
    "\n",
    "    def call(self, x):\n",
    "        maxlen = tf.shape(x)[-1]\n",
    "        positions = tf.range(start=0, limit=maxlen, delta=1)\n",
    "        positions = self.pos_emb(positions)\n",
    "        x = self.token_emb(x)\n",
    "        return x + positions\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Implement miniature GPT model\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "vocab_size = 20000  # Only consider the top 20k words\n",
    "maxlen = 100  # Max sequence size\n",
    "embed_dim = 256  # Embedding size for each token\n",
    "num_heads = 2  # Number of attention heads\n",
    "feed_forward_dim = 256  # Hidden layer size in feed forward network inside transformer\n",
    "\n",
    "\n",
    "def create_model():\n",
    "    inputs = layers.Input(shape=(maxlen,), dtype=tf.int32)\n",
    "    embedding_layer = TokenAndPositionEmbedding(maxlen, vocab_size, embed_dim)\n",
    "    x = embedding_layer(inputs)\n",
    "    transformer_block = TransformerBlock(embed_dim, num_heads, feed_forward_dim)\n",
    "    x = transformer_block(x)\n",
    "    outputs = layers.Dense(vocab_size)(x)\n",
    "    model = keras.Model(inputs=inputs, outputs=[outputs, x])\n",
    "    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n",
    "    model.compile(\n",
    "        \"adam\", loss=[loss_fn, None],\n",
    "    )  # No loss and optimization based on word embeddings from transformer block\n",
    "    return model\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Prepare data for word-level language modelling\n",
    "\n",
    "We download the IMDB data and combine the training and test sets for the\n",
    "text generation task.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "!curl -O https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz\n",
    "!tar -xf aclImdb_v1.tar.gz\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "\n",
    "batch_size = 32\n",
    "\n",
    "# The dataset contains each review in a separate text file\n",
    "# The text files are present in four different folders\n",
    "# Create a list of all files\n",
    "filenames = []\n",
    "directories = [\n",
    "    \"aclImdb/train/pos\",\n",
    "    \"aclImdb/train/neg\",\n",
    "    \"aclImdb/test/pos\",\n",
    "    \"aclImdb/test/neg\",\n",
    "]\n",
    "for directory in directories:\n",
    "    for f in os.listdir(directory):\n",
    "        filenames.append(os.path.join(directory, f))\n",
    "\n",
    "print(f\"{len(filenames)} files\")\n",
    "\n",
    "# Create dataset from text files\n",
    "random.shuffle(filenames)\n",
    "text_ds = tf.data.TextLineDataset(filenames)\n",
    "text_ds = text_ds.shuffle(buffer_size=256)\n",
    "text_ds = text_ds.batch(batch_size)\n",
    "\n",
    "\n",
    "def custom_standardization(input_string):\n",
    "    \"\"\"Lowercase, remove HTML line-break tags, and put a space before punctuation.\"\"\"\n",
    "    lowercased = tf.strings.lower(input_string)\n",
    "    stripped_html = tf.strings.regex_replace(lowercased, \"<br />\", \" \")\n",
    "    return tf.strings.regex_replace(stripped_html, f\"([{string.punctuation}])\", r\" \\1\")\n",
    "\n",
    "\n",
    "# Create a vectorization layer and adapt it to the text\n",
    "vectorize_layer = TextVectorization(\n",
    "    standardize=custom_standardization,\n",
    "    max_tokens=vocab_size - 1,\n",
    "    output_mode=\"int\",\n",
    "    output_sequence_length=maxlen + 1,\n",
    ")\n",
    "vectorize_layer.adapt(text_ds)\n",
    "vocab = vectorize_layer.get_vocabulary()  # To get words back from token indices\n",
    "\n",
    "\n",
    "def prepare_lm_inputs_labels(text):\n",
    "    \"\"\"\n",
    "    Shift word sequences by 1 position so that the target for position (i) is\n",
    "    word at position (i+1). The model will use all words up till position (i)\n",
    "    to predict the next word.\n",
    "    \"\"\"\n",
    "    text = tf.expand_dims(text, -1)\n",
    "    tokenized_sentences = vectorize_layer(text)\n",
    "    x = tokenized_sentences[:, :-1]\n",
    "    y = tokenized_sentences[:, 1:]\n",
    "    return x, y\n",
    "\n",
    "\n",
    "text_ds = text_ds.map(prepare_lm_inputs_labels)\n",
    "text_ds = text_ds.prefetch(tf.data.experimental.AUTOTUNE)\n",
    "\n"
   ]
  },
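  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "To make the one-position shift in `prepare_lm_inputs_labels` concrete, here is a\n",
    "minimal sketch on a single row of made-up token indices. The values and the\n",
    "`demo_` names are assumptions for illustration and do not come from the dataset.\n",
    "The row has length 6, which stands in for `maxlen + 1`, and zeros play the role\n",
    "of padding. The target at each position is the input token one step to the right.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Hypothetical token indices for one padded review\n",
    "demo_tokens = np.array([[12, 7, 45, 3, 0, 0]])\n",
    "\n",
    "demo_x = demo_tokens[:, :-1]  # inputs: everything except the last token\n",
    "demo_y = demo_tokens[:, 1:]  # targets: everything except the first token\n",
    "\n",
    "print(demo_x.tolist())  # [[12, 7, 45, 3, 0]]\n",
    "print(demo_y.tolist())  # [[7, 45, 3, 0, 0]]\n"
   ]
  },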
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Callback for generating text\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "\n",
    "class TextGenerator(keras.callbacks.Callback):\n",
    "    \"\"\"Callback to generate text from trained model.\n",
    "    1. Feed some starting prompt to the model\n",
    "    2. Predict probabilities for next token\n",
    "    3. Sample next token and add it to the next input\n",
    "\n",
    "    # Arguments\n",
    "        max_tokens: Integer, the number of tokens to be generated after prompt.\n",
    "        start_tokens: List of integers, the token indices for the starting prompt.\n",
    "        index_to_word: List of strings, obtained from TextVectorization layer.\n",
    "        top_k: Integer, sample from the `top_k` token predictions.\n",
    "        print_every: Integer, print after this many epochs.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self, max_tokens, start_tokens, index_to_word, top_k=10, print_every=1\n",
    "    ):\n",
    "        super().__init__()\n",
    "        self.max_tokens = max_tokens\n",
    "        self.start_tokens = start_tokens\n",
    "        self.index_to_word = index_to_word\n",
    "        self.print_every = print_every\n",
    "        self.k = top_k\n",
    "\n",
    "    def sample_from(self, logits):\n",
    "        logits, indices = tf.math.top_k(logits, k=self.k, sorted=True)\n",
    "        indices = np.asarray(indices).astype(\"int32\")\n",
    "        preds = keras.activations.softmax(tf.expand_dims(logits, 0))[0]\n",
    "        preds = np.asarray(preds).astype(\"float32\")\n",
    "        return np.random.choice(indices, p=preds)\n",
    "\n",
    "    def detokenize(self, number):\n",
    "        return self.index_to_word[number]\n",
    "\n",
    "    def on_epoch_end(self, epoch, logs=None):\n",
    "        start_tokens = [_ for _ in self.start_tokens]\n",
    "        if (epoch + 1) % self.print_every != 0:\n",
    "            return\n",
    "        num_tokens_generated = 0\n",
    "        tokens_generated = []\n",
    "        # Generate exactly max_tokens tokens after the prompt\n",
    "        while num_tokens_generated < self.max_tokens:\n",
    "            pad_len = maxlen - len(start_tokens)\n",
    "            sample_index = len(start_tokens) - 1\n",
    "            if pad_len < 0:\n",
    "                x = start_tokens[:maxlen]\n",
    "                sample_index = maxlen - 1\n",
    "            elif pad_len > 0:\n",
    "                x = start_tokens + [0] * pad_len\n",
    "            else:\n",
    "                x = start_tokens\n",
    "            x = np.array([x])\n",
    "            y, _ = self.model.predict(x)\n",
    "            sample_token = self.sample_from(y[0][sample_index])\n",
    "            tokens_generated.append(sample_token)\n",
    "            start_tokens.append(sample_token)\n",
    "            num_tokens_generated = len(tokens_generated)\n",
    "        txt = \" \".join(\n",
    "            [self.detokenize(_) for _ in self.start_tokens + tokens_generated]\n",
    "        )\n",
    "        print(f\"generated text:\\n{txt}\\n\")\n",
    "\n",
    "\n",
    "# Tokenize starting prompt\n",
    "word_to_index = {}\n",
    "for index, word in enumerate(vocab):\n",
    "    word_to_index[word] = index\n",
    "\n",
    "start_prompt = \"this movie is\"\n",
    "start_tokens = [word_to_index.get(_, 1) for _ in start_prompt.split()]\n",
    "num_tokens_generated = 40\n",
    "text_gen_callback = TextGenerator(num_tokens_generated, start_tokens, vocab)\n",
    "\n"
   ]
  },
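  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "The top-k sampling step in `sample_from` can also be sketched in plain NumPy.\n",
    "This is an illustrative re-implementation under assumed names\n",
    "(`demo_top_k_sample`, `demo_logits`), not the callback's own code: it keeps the\n",
    "`k` largest logits, turns them into probabilities with a softmax, and draws one\n",
    "index from that restricted distribution. With `k=2`, only the two highest-scoring\n",
    "indices can ever be sampled.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def demo_top_k_sample(logits, k, rng):\n",
    "    # Hypothetical helper: keep the k largest logits, softmax them, sample one\n",
    "    top_indices = np.argsort(logits)[-k:]\n",
    "    top_logits = logits[top_indices]\n",
    "    probs = np.exp(top_logits - top_logits.max())\n",
    "    probs = probs / probs.sum()\n",
    "    return int(rng.choice(top_indices, p=probs))\n",
    "\n",
    "\n",
    "demo_rng = np.random.RandomState(0)\n",
    "demo_logits = np.array([2.0, 0.5, -1.0, 3.0, 0.0])\n",
    "demo_samples = {demo_top_k_sample(demo_logits, 2, demo_rng) for _ in range(100)}\n",
    "print(demo_samples)  # a subset of {0, 3}: only the two largest logits are eligible\n"
   ]
  },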
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Train\n",
    "\n",
    "Note: This code should preferably be run on a GPU.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "model = create_model()\n",
    "\n",
    "model.fit(text_ds, verbose=2, epochs=30, callbacks=[text_gen_callback])\n"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "text_generation_with_miniature_gpt",
   "private_outputs": false,
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
